package com.zhl.gateway.filter;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.zhl.common.vo.Result;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.cloud.gateway.route.Route;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.stream.Stream;

/**
 * Gateway filter factory that applies rate limiting to routed requests.
 *
 * <p>For each request the produced filter resolves a key through a {@link KeyResolver},
 * asks a {@link RateLimiter} whether that key is allowed for the current route, copies the
 * headers returned by the limiter onto the response, and then either continues the filter
 * chain or answers with a JSON error body.
 *
 * @author 凌晨
 * @version V1.0
 * @date 2020/10/16 21:39
 */
@Component
@Slf4j
public class RateCheckGatewayFilterFactory extends AbstractGatewayFilterFactory<RateCheckGatewayFilterFactory.Config> {

    /** Placeholder substituted when the key resolver completes without emitting a key. */
    private static final String EMPTY_KEY = "____EMPTY_KEY__";

    private final RateLimiter defaultRateLimiter;
    private final KeyResolver defaultKeyResolver;

    /**
     * Creates the factory with the fallbacks used when a route's {@link Config} does not
     * provide its own limiter or key resolver.
     *
     * @param defaultRateLimiter limiter used when {@code Config.rateLimiter} is {@code null}
     * @param defaultKeyResolver resolver used when {@code Config.keyResolver} is {@code null}
     */
    public RateCheckGatewayFilterFactory(RateLimiter defaultRateLimiter, KeyResolver defaultKeyResolver) {
        // Required: hands the Config class to the parent factory.
        super(RateCheckGatewayFilterFactory.Config.class);
        this.defaultRateLimiter = defaultRateLimiter;
        this.defaultKeyResolver = defaultKeyResolver;
    }

    /**
     * @return the limiter used when a route configuration does not supply one
     */
    public RateLimiter getDefaultRateLimiter() {
        return defaultRateLimiter;
    }

    /**
     * @return the key resolver used when a route configuration does not supply one
     */
    public KeyResolver getDefaultKeyResolver() {
        return defaultKeyResolver;
    }

    /**
     * Builds the rate-limiting filter for one route configuration.
     *
     * <p>Behaviour of the returned filter:
     * <ul>
     *   <li>no key resolved: the request is denied with {@code 403 FORBIDDEN} and the chain
     *       is not invoked;</li>
     *   <li>limiter allows the key: limiter headers are added and the chain continues;</li>
     *   <li>limiter rejects the key: limiter headers are added and a JSON error body is
     *       written with {@link Config#getStatusCode()} (429 when that is {@code null}).</li>
     * </ul>
     *
     * @param config route-level settings; {@code null} members fall back to the defaults
     *               given to the constructor
     * @return the filter enforcing the limit
     */
    @SuppressWarnings("unchecked")
    @Override
    public GatewayFilter apply(Config config) {
        log.info("类:{}今日过滤器", this.getClass().getSimpleName());
        KeyResolver resolver = (config.keyResolver == null) ? defaultKeyResolver : config.keyResolver;
        // Raw-to-parameterized assignment; this is what the method-level suppression covers.
        RateLimiter<Object> limiter = (config.rateLimiter == null) ? defaultRateLimiter : config.rateLimiter;
        log.info("配置说明resolver:{}", JSONObject.toJSONString(Stream.of(resolver).toArray()));
        log.info("配置说明limiter:{}", JSONObject.toJSONString(limiter));
        return (exchange, chain) -> {
            // NOTE(review): assumes the route attribute has been populated before this filter
            // runs; a missing attribute would surface as a NullPointerException below.
            Route route = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
            // Serializing the route is costly, so only do it when the output is wanted.
            if (log.isDebugEnabled()) {
                log.debug("配置说明内部route:{}", JSONObject.toJSONString(route));
            }
            // A resolver that completes empty would otherwise skip the flatMap entirely,
            // ending the exchange without routing and without any status; map that case to
            // a sentinel key so it can be handled explicitly.
            return resolver.resolve(exchange).defaultIfEmpty(EMPTY_KEY).flatMap(key -> {
                if (EMPTY_KEY.equals(key)) {
                    log.warn("No rate-limit key resolved for route {}; denying request", route.getId());
                    ServerHttpResponse denied = exchange.getResponse();
                    denied.setStatusCode(HttpStatus.FORBIDDEN);
                    return denied.setComplete();
                }
                return limiter.isAllowed(route.getId(), key).flatMap(response -> {
                    log.debug("配置说明内部{}___{}", route.getId(), key);
                    // Expose the limiter's headers to the client whether or not it passes.
                    for (Map.Entry<String, String> header : response.getHeaders().entrySet()) {
                        exchange.getResponse().getHeaders().add(header.getKey(), header.getValue());
                    }
                    if (response.isAllowed()) {
                        return chain.filter(exchange);
                    }
                    HttpStatus configured = config.getStatusCode();
                    HttpStatus status = (configured != null) ? configured : HttpStatus.TOO_MANY_REQUESTS;
                    return writeRejection(exchange.getResponse(), status);
                });
            });
        };
    }

    /**
     * Writes the JSON rejection body and completes the response.
     *
     * @param rs     response to write to
     * @param status HTTP status to set on the response
     * @return completion signal of the write
     */
    private Mono<Void> writeRejection(ServerHttpResponse rs, HttpStatus status) {
        byte[] datas = JSON.toJSONString(new Result<>().error500("TOO_MANY_REQUESTS"))
                .getBytes(StandardCharsets.UTF_8);
        DataBuffer buffer = rs.bufferFactory().wrap(datas);
        rs.setStatusCode(status);
        // set (not add) so an existing Content-Type value is replaced rather than duplicated.
        rs.getHeaders().set(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_UTF8_VALUE);
        return rs.writeWith(Mono.just(buffer));
    }

    /**
     * Basic per-route parameters for this filter; the class is handed to the parent factory
     * in the constructor so that the factory can bind values onto it.
     *
     * <p>Setters return {@code this} to allow chaining.
     *
     * Author: peikunkun
     * Date: 2019/8/21 0021 6:10 PM
     */
    public static class Config {

        /** Route-specific key resolver; {@code null} means use the factory default. */
        private KeyResolver keyResolver;
        /** Route-specific limiter; {@code null} means use the factory default. */
        private RateLimiter rateLimiter;
        /** Status returned when the limiter rejects a request. */
        private HttpStatus statusCode = HttpStatus.TOO_MANY_REQUESTS;

        public KeyResolver getKeyResolver() {
            return keyResolver;
        }

        public Config setKeyResolver(KeyResolver keyResolver) {
            this.keyResolver = keyResolver;
            return this;
        }

        public RateLimiter getRateLimiter() {
            return rateLimiter;
        }

        public Config setRateLimiter(RateLimiter rateLimiter) {
            this.rateLimiter = rateLimiter;
            return this;
        }

        public HttpStatus getStatusCode() {
            return statusCode;
        }

        public Config setStatusCode(HttpStatus statusCode) {
            this.statusCode = statusCode;
            return this;
        }
    }
}
